import os.path as osp
import h5py
import numpy as np
import warnings
from tqdm import tqdm

import torch
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset
from scipy.spatial import ConvexHull
from torch_geometric.nn import radius_graph
from Fragment_2025_01_23.mol_unit_sphere import Frame
from torch.utils.data import DataLoader
from PDB import get_pdb_info_EC, download_pdb_files

import os

class CustomData(Data):
    """
    `Data` subclass holding a residue-level graph (A) together with a second,
    coarser graph (B), whose node count is stored in `num_nodes_b`.

    Only `__inc__` is overridden, so that attributes indexing into graph B are
    incremented by graph B's node count instead of the default increment.
    """

    def __inc__(self, key, value, *args, **kwargs):
        # These attributes hold indices into graph B, so their increment is
        # the number of nodes in B rather than the number of nodes in A.
        # NOTE(review): `ch_b_edge_index` is offset by `num_nodes_b` as well;
        # this assumes `ch_b_pos` has one row per node of B - confirm.
        if key in ("mapping_a_to_b", "ch_b_edge_index"):
            return self.num_nodes_b
        # Every other attribute keeps the parent class behaviour.
        return super().__inc__(key, value, *args, **kwargs)

# Integer code (1..20) for each standard amino acid, keyed by one-letter symbol.
# The position of a symbol in the string below fixes its code:
#    1 A Alanine         2 R Arginine        3 N Asparagine      4 D Aspartic acid
#    5 C Cysteine        6 E Glutamic acid   7 Q Glutamine       8 G Glycine
#    9 H Histidine      10 I Isoleucine     11 L Leucine        12 K Lysine
#   13 M Methionine     14 F Phenylalanine  15 P Proline        16 S Serine
#   17 T Threonine      18 W Tryptophan     19 Y Tyrosine       20 V Valine
one_letter_to_number = {
    symbol: code
    for code, symbol in enumerate("ARNDCEQGHILKMFPSTWYV", start=1)
}

def list_files_in_folder(folder_path):
    """
    Lists the names of all regular files directly inside the given folder.

    Sub-directories are skipped and the folder is not searched recursively.

    Args:
        folder_path (str): The path of the folder.

    Returns:
        list: A list of file names (not full paths) in the folder. If the
        folder cannot be listed (missing, not a directory, no permission,
        invalid path value) the error is printed and an empty list is
        returned, so callers can always iterate over or test membership in
        the result.
    """
    try:
        return [
            name
            for name in os.listdir(folder_path)
            if os.path.isfile(os.path.join(folder_path, name))
        ]
    except (OSError, TypeError, ValueError) as e:
        # Best effort: report the problem and fall back to "no files".
        print(f"Error in listing files in folder: {e}")
        return []

class ECdataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None,
                 pre_filter=None, split='train'):
        """
        Set up the dataset for one split and load its processed file.

        Args:
            root: Dataset root folder.
            transform: Forwarded to the parent constructor and stored.
            pre_transform: Forwarded to the parent constructor and stored.
            pre_filter: Forwarded to the parent constructor and stored.
            split: Name of the split; it is part of `processed_dir` and of
                `raw_file_names`, and selects the list file in `process`.
        """
        # Both attributes are read by `processed_dir`, so they are assigned
        # before the parent constructor runs.
        self.split = split
        self.root = root

        super().__init__(root, transform, pre_transform, pre_filter)

        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter

        # NOTE(review): on PyTorch releases where `torch.load` defaults to
        # `weights_only=True`, loading pickled graph objects may require
        # passing `weights_only=False` - confirm against the installed version.
        processed_path = self.processed_paths[0]
        self.data, self.slices = torch.load(processed_path)

    @property
    def processed_dir(self):
        name = 'processed'
        return osp.join(self.root, name, self.split)

    @property
    def raw_file_names(self):
        name = self.split + '.txt'
        return name

    @property
    def processed_file_names(self):
        return 'data.pt'


    def _normalize(self,tensor, dim=-1):
        '''
        Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
        '''
        return torch.nan_to_num(
            torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))

    def get_atom_pos(self, amino_types, atom_names, atom_amino_id, atom_pos):
        # atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1
        mask_n = np.char.equal(atom_names, b'N')
        mask_ca = np.char.equal(atom_names, b'CA')
        mask_c = np.char.equal(atom_names, b'C')
        mask_cb = np.char.equal(atom_names, b'CB')
        mask_g = np.char.equal(atom_names, b'CG') | np.char.equal(atom_names, b'SG') | np.char.equal(atom_names, b'OG') | np.char.equal(atom_names, b'CG1') | np.char.equal(atom_names, b'OG1')
        mask_d = np.char.equal(atom_names, b'CD') | np.char.equal(atom_names, b'SD') | np.char.equal(atom_names, b'CD1') | np.char.equal(atom_names, b'OD1') | np.char.equal(atom_names, b'ND1')
        mask_e = np.char.equal(atom_names, b'CE') | np.char.equal(atom_names, b'NE') | np.char.equal(atom_names, b'OE1')
        mask_z = np.char.equal(atom_names, b'CZ') | np.char.equal(atom_names, b'NZ')
        mask_h = np.char.equal(atom_names, b'NH1')

        pos_n = np.full((len(amino_types),3),np.nan)
        pos_n[atom_amino_id[mask_n]] = atom_pos[mask_n]
        pos_n = torch.FloatTensor(pos_n)

        pos_ca = np.full((len(amino_types),3),np.nan)
        pos_ca[atom_amino_id[mask_ca]] = atom_pos[mask_ca]
        pos_ca = torch.FloatTensor(pos_ca)

        pos_c = np.full((len(amino_types),3),np.nan)
        pos_c[atom_amino_id[mask_c]] = atom_pos[mask_c]
        pos_c = torch.FloatTensor(pos_c)

        # if data only contain pos_ca, we set the position of C and N as the position of CA
        pos_n[torch.isnan(pos_n)] = pos_ca[torch.isnan(pos_n)]
        pos_c[torch.isnan(pos_c)] = pos_ca[torch.isnan(pos_c)]

        pos_cb = np.full((len(amino_types),3),np.nan)
        pos_cb[atom_amino_id[mask_cb]] = atom_pos[mask_cb]
        pos_cb = torch.FloatTensor(pos_cb)

        pos_g = np.full((len(amino_types),3),np.nan)
        pos_g[atom_amino_id[mask_g]] = atom_pos[mask_g]
        pos_g = torch.FloatTensor(pos_g)

        pos_d = np.full((len(amino_types),3),np.nan)
        pos_d[atom_amino_id[mask_d]] = atom_pos[mask_d]
        pos_d = torch.FloatTensor(pos_d)

        pos_e = np.full((len(amino_types),3),np.nan)
        pos_e[atom_amino_id[mask_e]] = atom_pos[mask_e]
        pos_e = torch.FloatTensor(pos_e)

        pos_z = np.full((len(amino_types),3),np.nan)
        pos_z[atom_amino_id[mask_z]] = atom_pos[mask_z]
        pos_z = torch.FloatTensor(pos_z)

        pos_h = np.full((len(amino_types),3),np.nan)
        pos_h[atom_amino_id[mask_h]] = atom_pos[mask_h]
        pos_h = torch.FloatTensor(pos_h)

        return pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h


    def side_chain_embs(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h):
        v1, v2, v3, v4, v5, v6, v7 = pos_ca - pos_n, pos_cb - pos_ca, pos_g - pos_cb, pos_d - pos_g, pos_e - pos_d, pos_z - pos_e, pos_h - pos_z

        # five side chain torsion angles
        # We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0.
        angle1 = torch.unsqueeze(self.compute_dihedrals(v1, v2, v3),1)
        angle2 = torch.unsqueeze(self.compute_dihedrals(v2, v3, v4),1)
        angle3 = torch.unsqueeze(self.compute_dihedrals(v3, v4, v5),1)
        angle4 = torch.unsqueeze(self.compute_dihedrals(v4, v5, v6),1)
        angle5 = torch.unsqueeze(self.compute_dihedrals(v5, v6, v7),1)

        side_chain_angles = torch.cat((angle1, angle2, angle3, angle4),1)
        side_chain_embs = torch.cat((torch.sin(side_chain_angles), torch.cos(side_chain_angles)),1)

        return side_chain_embs


    def bb_embs(self, X):
        # X should be a num_residues x 3 x 3, order N, C-alpha, and C atoms of each residue
        # N coords: X[:,0,:]
        # CA coords: X[:,1,:]
        # C coords: X[:,2,:]
        # return num_residues x 6
        # From https://github.com/jingraham/neurips19-graph-protein-design

        X = torch.reshape(X, [3 * X.shape[0], 3])
        dX = X[1:] - X[:-1]
        U = self._normalize(dX, dim=-1)
        u0 = U[:-2]
        u1 = U[1:-1]
        u2 = U[2:]

        angle = self.compute_dihedrals(u0, u1, u2)

        # add phi[0], psi[-1], omega[-1] with value 0
        angle = F.pad(angle, [1, 2])
        angle = torch.reshape(angle, [-1, 3])
        angle_features = torch.cat([torch.cos(angle), torch.sin(angle)], 1)
        return angle_features


    def compute_dihedrals(self, v1, v2, v3):
        n1 = torch.cross(v1, v2)
        n2 = torch.cross(v2, v3)
        a = (n1 * n2).sum(dim=-1)
        b = torch.nan_to_num((torch.cross(n1, n2) * v2).sum(dim=-1) / v2.norm(dim=1))
        torsion = torch.nan_to_num(torch.atan2(b, a))
        return torsion

    def pdb_info_to_graph(self, df, edges, s_graph_dict):
        """
        Assemble one `CustomData` graph from parsed PDB information.

        Args:
            df: Per-residue table. The columns read here are 'CA', 'N', 'C'
                (coordinates), 'aa' (one-letter amino acid), 'ss_ser', and
                the eleven feature columns listed below.
            edges: Edge list of the residue graph, stored as `edge_index`
                (presumably of shape (2, num_edges) - confirm against
                `get_pdb_info_EC`).
            s_graph_dict: Secondary-structure graph; the keys 'coords' and
                'ss_num' are read.

        Returns:
            The populated `CustomData` object.

        Raises:
            AssertionError: If the per-residue attributes differ in length.
        """
        def column_as_tensor(name):
            # Stack the per-residue entries of one column into a float32 tensor.
            return torch.tensor(np.vstack(df[name].values), dtype=torch.float32)

        graph = CustomData()

        graph.coords_a_ca = column_as_tensor('CA')
        backbone_n = column_as_tensor('N')
        backbone_ca = column_as_tensor('CA')
        backbone_c = column_as_tensor('C')

        # NOTE: despite its name, `side_chain_embs` stores these eleven
        # per-residue feature columns, not side chain torsion angles.
        feature_columns = [
            'asa', 'phi', 'psi',
            'NH_O_1_relidx', 'NH_O_1_energy',
            'O_NH_1_relidx', 'O_NH_1_energy',
            'NH_O_2_relidx', 'NH_O_2_energy',
            'O_NH_2_relidx', 'O_NH_2_energy',
        ]
        residue_features = df[feature_columns].values
        graph.side_chain_embs = torch.tensor(np.vstack(residue_features), dtype=torch.float32)

        # Amino acid codes; letters missing from the table are encoded as -1.
        residue_codes = df['aa'].apply(lambda letter: one_letter_to_number.get(letter, -1)).values
        graph.x = torch.tensor(residue_codes, dtype=torch.float32)

        # three backbone torsion angles
        backbone = torch.stack((backbone_n, backbone_ca, backbone_c), dim=1)
        torsion_features = self.bb_embs(backbone)
        torsion_features[torch.isnan(torsion_features)] = 0
        graph.bb_embs = torsion_features
        graph.coords_a_n = backbone_n
        graph.coords_a_c = backbone_c

        # add edges
        graph.edge_index = torch.tensor(edges, dtype=torch.long).contiguous()

        # add secondary structure graph
        ss_coords = torch.tensor(s_graph_dict['coords'], dtype=torch.float32)
        graph.coords_b_ = ss_coords
        graph.ss_x = torch.tensor(s_graph_dict['ss_num'], dtype=torch.float32)
        graph.num_nodes_b = len(ss_coords)

        # mapping_a_to_b: for every residue, the value of its 'ss_ser' column
        graph.mapping_a_to_b = torch.tensor(df['ss_ser'].values, dtype=torch.long).contiguous()

        # All per-residue attributes must have the same number of rows.
        per_residue = (graph.x, graph.coords_a_ca, graph.coords_a_n,
                       graph.coords_a_c, graph.side_chain_embs, graph.bb_embs)
        assert len({len(attr) for attr in per_residue}) == 1

        return graph

    def process(self):
        """
        Build and save the processed file of the current split.

        Steps:
          1. Read the protein list of the split from `root`
             (`training.txt`, `validation.txt` or `testing.txt`).
          2. Read the integer label of every listed protein from
             `root/chain_functions.txt` (lines of the form `name,label`).
          3. Call `download_pdb_files` for every listed protein, targeting
             `root/pdb/`.
          4. For every protein whose `<pdb_id>.pdb` file is present, build
             its graph, attach the `ch_b_*` attributes obtained from
             `Frame.get_frame`, and collect it.
          5. Collate the graphs and save them to `processed_paths[0]`.

        The split name is matched case-insensitively ("train", "val",
        "test"), so the constructor default `'train'` is accepted.

        NOTE(review): at most 32 proteins are collected by default. This
        looks like a debugging limit; set an attribute `max_proteins` on the
        dataset (an int, or `None` for no limit) to change it.

        Raises:
            ValueError: If `self.split` is not one of the known splits.
            RuntimeError: If no protein could be processed.
            KeyError: If a listed protein has no entry in
                `chain_functions.txt`.
        """
        print('Beginning Processing ...')

        # Get the file list.
        split_files = {
            "train": "/training.txt",
            "val": "/validation.txt",
            "test": "/testing.txt",
        }
        try:
            splitFile = split_files[str(self.split).lower()]
        except KeyError:
            raise ValueError(
                "Unknown split {!r}; expected one of 'Train', 'Val', 'Test'".format(self.split)
            ) from None

        # One protein per line, written as `<pdb_id>.<case_id>`; blank lines
        # are ignored.
        proteinNames_ = []
        with open(self.root+splitFile, 'r') as mFile:
            for line in mFile:
                name = line.rstrip()
                if name:
                    proteinNames_.append(name)

        # Load the functions.
        print("Reading protein functions")
        wanted = set(proteinNames_)  # constant-time membership per label line
        protFunct_ = {}
        with open(self.root+"/chain_functions.txt", 'r') as mFile:
            for line in mFile:
                splitLine = line.rstrip().split(',')
                if splitLine[0] in wanted:
                    protFunct_[splitLine[0]] = int(splitLine[1])

        # Load the dataset
        print('Downloading PDB files...')
        folder_path = self.root+'/pdb/'
        os.makedirs(folder_path, exist_ok=True)
        for name in tqdm(proteinNames_, desc="Downloading PDBs for {} set".format(self.split)):
            fileName = name.split('/')[-1]
            pdb_id, case_id = fileName.split('.')
            download_pdb_files(pdb_id, folder_path)
        # `or ()` tolerates a failed listing that yields None.
        available = set(list_files_in_folder(folder_path) or ())

        # Upper bound on collected proteins (see NOTE in the docstring).
        max_proteins = getattr(self, 'max_proteins', 32)

        print("Reading the data")
        count = 0
        frame = Frame()
        data_list = []
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            k = 0
            for name in tqdm(proteinNames_, desc="Processing proteins for {} set".format(self.split)):
                fileName = name.split('/')[-1]
                pdb_id, case_id = fileName.split('.')
                # Skip proteins whose PDB file is not in the folder.
                if pdb_id+'.pdb' not in available:
                    continue
                file_path = folder_path + pdb_id + '.pdb'
                df, edges, s_graph_dict = get_pdb_info_EC(file_path, case_id)
                curProtein = self.pdb_info_to_graph(df, edges, s_graph_dict)
                curProtein.id = fileName
                curProtein.y = torch.tensor(protFunct_[name])

                # Hull attributes computed from the coordinates of graph B.
                pos = curProtein.coords_b_
                _, shell_data_ch, edge_index_hull = frame.get_frame(pos.numpy())
                ch_pos = torch.tensor(shell_data_ch, dtype=torch.float)
                # Distance of every hull point to the centroid of the points.
                ch_r = torch.norm(ch_pos - torch.mean(ch_pos, dim=0), dim=-1)

                curProtein['ch_b_pos'] = ch_pos
                curProtein['ch_b_r'] = ch_r
                curProtein['ch_b_edge_index'] = torch.tensor(edge_index_hull, dtype=torch.long)

                if curProtein.x is not None:
                    data_list.append(curProtein)
                    count += 1
                k += 1
                print('Processed {} / {} proteins'.format(k, len(proteinNames_)))
                if max_proteins is not None and count == max_proteins:
                    break

        if not data_list:
            # Collating an empty list would fail with an unrelated error.
            raise RuntimeError(
                "No protein of the {} split could be processed".format(self.split))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        print('Done!')
        print('Ratio of proteins processed: {:.2f}%'.format(count/len(proteinNames_)*100))
